{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Finetuning corpus embeddings using NUDGE\n",
    "[NUDGE](https://www.arxiv.org/abs/2409.02343) is a simple, lightweight fine-tuning method that boosts accuracy when retrieving text by semantic similarity with pre-trained embedding models. NUDGE directly modifies the embeddings of data records to maximize the similarity between training queries and their ground-truth answers. It does so non-parametrically: rather than modifying model parameters to generate better embeddings, as fine-tuning the embedding model or training an adapter would, NUDGE changes the embeddings themselves. Compared with fine-tuning the pre-trained model and training adapters, NUDGE provides 3.3x and 4.3x higher increases in accuracy and runs 200x and 3x faster, respectively. For more details, see the [blog post](https://data-people-group.github.io/blogs/2024/09/05/nudge/) and the [paper](https://www.arxiv.org/abs/2409.02343).\n",
    "\n",
    "We demonstrate NUDGE's effectiveness on SciFact, a commonly used information retrieval benchmark."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install llama-index-experimental llama-index-embeddings-huggingface nudge-ft torch datasets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the scifact benchmark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:datasets:PyTorch version 2.5.0a0+872d972e41.nv24.8 available.\n",
      "PyTorch version 2.5.0a0+872d972e41.nv24.8 available.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from llama_index.finetuning import EmbeddingQAFinetuneDataset\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "def load_hf_dataset(dataset_name):\n",
    "    hf_dataset_name = f\"sepz/{dataset_name}_ft\"\n",
    "    corpus = load_dataset(hf_dataset_name, \"data_records\", split=\"train\")\n",
    "\n",
    "    queries_train = load_dataset(hf_dataset_name, \"qs\", split=\"train\")\n",
    "    queries_validation = load_dataset(hf_dataset_name, \"qs\", split=\"dev\")\n",
    "    queries_test = load_dataset(hf_dataset_name, \"qs\", split=\"test\")\n",
    "\n",
    "    qrels_train = load_dataset(hf_dataset_name, \"qs_rel\", split=\"train\")\n",
    "    qrels_validation = load_dataset(hf_dataset_name, \"qs_rel\", split=\"dev\")\n",
    "    qrels_test = load_dataset(hf_dataset_name, \"qs_rel\", split=\"test\")\n",
    "\n",
    "    corpus = {\n",
    "        str(corpus[i][\"record_id\"]): corpus[i][\"text\"]\n",
    "        for i in range(len(corpus))\n",
    "    }\n",
    "\n",
    "    queries_train = {\n",
    "        str(queries_train[i][\"q_id\"]): queries_train[i][\"input\"]\n",
    "        for i in range(len(queries_train))\n",
    "    }\n",
    "    queries_validation = {\n",
    "        str(r[\"q_id\"]): r[\"input\"] for r in queries_validation\n",
    "    }\n",
    "    queries_test = {str(r[\"q_id\"]): r[\"input\"] for r in queries_test}\n",
    "\n",
    "    qrels_train = (\n",
    "        qrels_train.to_pandas()\n",
    "        .groupby(\"q_id\")[\"record_id\"]\n",
    "        .apply(list)\n",
    "        .to_dict()\n",
    "    )\n",
    "    qrels_validation = (\n",
    "        qrels_validation.to_pandas()\n",
    "        .groupby(\"q_id\")[\"record_id\"]\n",
    "        .apply(list)\n",
    "        .to_dict()\n",
    "    )\n",
    "    qrels_test = (\n",
    "        qrels_test.to_pandas()\n",
    "        .groupby(\"q_id\")[\"record_id\"]\n",
    "        .apply(list)\n",
    "        .to_dict()\n",
    "    )\n",
    "    # convert to strings\n",
    "    qrels_train = {str(k): [str(i) for i in v] for k, v in qrels_train.items()}\n",
    "    qrels_validation = {\n",
    "        str(k): [str(i) for i in v] for k, v in qrels_validation.items()\n",
    "    }\n",
    "    qrels_test = {str(k): [str(i) for i in v] for k, v in qrels_test.items()}\n",
    "\n",
    "    # Load the dataset\n",
    "    train_dataset = EmbeddingQAFinetuneDataset(\n",
    "        corpus=corpus, queries=queries_train, relevant_docs=qrels_train\n",
    "    )\n",
    "    validation_dataset = EmbeddingQAFinetuneDataset(\n",
    "        corpus=corpus,\n",
    "        queries=queries_validation,\n",
    "        relevant_docs=qrels_validation,\n",
    "    )\n",
    "    test_dataset = EmbeddingQAFinetuneDataset(\n",
    "        corpus=corpus, queries=queries_test, relevant_docs=qrels_test\n",
    "    )\n",
    "\n",
    "    return train_dataset, validation_dataset, test_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the dataset and base embedding model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5\n",
      "Load pretrained SentenceTransformer: BAAI/bge-small-en-v1.5\n",
      "INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']\n",
      "2 prompts are loaded, with the keys: ['query', 'text']\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core.embeddings import resolve_embed_model\n",
    "\n",
    "train_dataset, val_dataset, test_dataset = load_hf_dataset(\"scifact\")\n",
    "base_embed_model = resolve_embed_model(\"local:BAAI/bge-small-en-v1.5\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If we take a peek at the dataset, we can see that it's structured as\n",
    "- corpus: a mapping of document ID to text\n",
    "- queries: a mapping of query ID to query text\n",
    "- relevant_docs: a mapping of query ID to a list of relevant document IDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Depletion of nitric oxide is responsible for vasospasm.\n"
     ]
    }
   ],
   "source": [
    "print(val_dataset.queries[\"2\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['552']\n"
     ]
    }
   ],
   "source": [
    "print(val_dataset.relevant_docs[\"2\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CONTEXT Delayed cerebral vasospasm causes permanent neurological deficits or death in at least 15% of patients following otherwise successful treatment for ruptured intracranial aneurysm. Decreased bioavailability of nitric oxide has been associated with the development of cerebral vasospasm. OBJECTIVE To determine whether infusions of nitrite will prevent delayed cerebral vasospasm. DESIGN, SETTING, AND SUBJECTS A total of 14 anesthetized cynomolgus monkeys had an autologous blood clot placed around the right middle cerebral artery. Cerebral arteriography was performed before clot placement and on days 7 and 14 to assess vasospasm. The study was conducted from August 2003 to February 2004. INTERVENTIONS A 90-mg sodium nitrite intravenous solution infused over 24 hours plus a 45-mg sodium nitrite bolus daily (n = 3); a 180-mg sodium nitrite intravenous solution infused over 24 hours (n = 3); or a control saline solution infusion (n = 8). Each was infused continuously for 14 days. MAIN OUTCOME MEASURES Nitrite, S-nitrosothiol, and methemoglobin levels in blood and cerebrospinal fluid and degree of arteriographic vasospasm. RESULTS In control monkeys, mean (SD) cerebrospinal fluid nitrite levels decreased from 3.1 (1.5) micromol/L to 0.4 (0.1) micromol/L at day 7 and to 0.4 (0.4) micromol/L at day 14 (P = .03). All 8 control monkeys developed significant vasospasm of the right middle cerebral artery, which was complicated by stroke and death in 1 animal. Sodium nitrite infusions increased the nitrite and methemoglobin levels (<2.1% of total hemoglobin) in the blood and cerebrospinal fluid without evoking systemic hypotension. Nitrite infusion prevented development of vasospasm (no animals developed significant vasospasm; mean [SD] reduction in right middle cerebral artery area on day 7 after subarachnoid hemorrhage of 8% [9%] in nitrite-treated monkeys vs 47% [5%] in saline-treated controls; P<.001). 
There was a negative correlation between the concentration of nitrite in cerebrospinal fluid and the degree of cerebral vasospasm (P<.001). Pharmacological effects of nitrite infusion were also associated with the formation of S-nitrosothiol in cerebrospinal fluid. There was no clinical or pathological evidence of nitrite toxicity. CONCLUSION Subacute sodium nitrite infusions prevented delayed cerebral vasospasm in a primate model of subarachnoid hemorrhage.\n"
     ]
    }
   ],
   "source": [
    "print(val_dataset.corpus[\"552\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Using your own Datasets\n",
    "\n",
    "As you can see, you can run this notebook on any dataset, as long as you have queries and a mapping to relevant documents! If you have documents but are missing a training set of queries, check out our tools for generating a synthetic dataset ([1](https://docs.llamaindex.ai/en/stable/api_reference/evaluation/dataset_generation/)).\n",
    "\n",
    "You can also write your own dataset by hand, or use llama-index to create one.\n",
    "\n",
    "Uncomment the code below and add your own files if you want to try it out."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate a dataset from your own custom data\n",
    "from llama_index.finetuning import generate_qa_embedding_pairs\n",
    "from llama_index.core import SimpleDirectoryReader\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from llama_index.core.evaluation import EmbeddingQAFinetuneDataset\n",
    "\n",
    "\n",
    "def load_corpus(files, verbose=False):\n",
    "    if verbose:\n",
    "        print(f\"Loading files {files}\")\n",
    "\n",
    "    reader = SimpleDirectoryReader(input_files=files)\n",
    "    docs = reader.load_data()\n",
    "    if verbose:\n",
    "        print(f\"Loaded {len(docs)} docs\")\n",
    "\n",
    "    parser = SentenceSplitter()\n",
    "    nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n",
    "\n",
    "    if verbose:\n",
    "        print(f\"Parsed {len(nodes)} nodes\")\n",
    "\n",
    "    return nodes\n",
    "\n",
    "\n",
    "# Load data\n",
    "# train_nodes = load_corpus([\"file1.pdf\", ...], verbose=True)\n",
    "# val_nodes = load_corpus([\"file2.pdf\", ...], verbose=True)\n",
    "\n",
    "# Generate pairs\n",
    "# train_dataset = generate_qa_embedding_pairs(train_nodes)\n",
    "# val_dataset = generate_qa_embedding_pairs(val_nodes)\n",
    "\n",
    "# [Optional] Save to disk\n",
    "# train_dataset.save_json(\"train_dataset.json\")\n",
    "# val_dataset.save_json(\"val_dataset.json\")\n",
    "\n",
    "# [Optional] Load\n",
    "# train_dataset = EmbeddingQAFinetuneDataset.from_json(\"train_dataset.json\")\n",
    "# val_dataset = EmbeddingQAFinetuneDataset.from_json(\"val_dataset.json\")"
   ]
  },
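  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "A common information retrieval metric to report during evaluation is NDCG@k (normalized discounted cumulative gain at rank k), which rewards placing relevant documents near the top of the first k results."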
   ]
  },
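  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reference for the metric, here is a sketch of NDCG@k with binary relevance, which is the variant the evaluation code in this notebook computes. The symbols $rel_i$ and $R$ are notation introduced here for illustration: $rel_i$ is 1 if the document retrieved at rank $i$ is relevant to the query and 0 otherwise, and $R$ is the query's set of relevant documents.\n",
    "\n",
    "$$\\mathrm{DCG@}k = \\sum_{i=1}^{k} \\frac{rel_i}{\\log_2(i+1)}, \\qquad \\mathrm{IDCG@}k = \\sum_{i=1}^{\\min(k, |R|)} \\frac{1}{\\log_2(i+1)}, \\qquad \\mathrm{NDCG@}k = \\frac{\\mathrm{DCG@}k}{\\mathrm{IDCG@}k}$$\n",
    "\n",
    "For example, if a query has a single relevant document and it is retrieved at rank 3, then $\\mathrm{DCG@}10 = 1/\\log_2 4 = 0.5$ and $\\mathrm{IDCG@}10 = 1/\\log_2 2 = 1$, so $\\mathrm{NDCG@}10 = 0.5$."
   ]
  },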
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Optional, Dict\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from llama_index.core.schema import TextNode\n",
    "from llama_index.core.base.embeddings.base import BaseEmbedding\n",
    "from llama_index.core.base.base_retriever import BaseRetriever\n",
    "from llama_index.core import VectorStoreIndex\n",
    "\n",
    "\n",
    "def build_retriever(\n",
    "    corpus: Dict[str, str],\n",
    "    embed_model: BaseEmbedding | str,\n",
    "    corpus_embeddings: Optional[torch.Tensor] = None,\n",
    "    k: int = 10,\n",
    ") -> BaseRetriever:\n",
    "    nodes = []\n",
    "    for i, (id_, text) in enumerate(corpus.items()):\n",
    "        if corpus_embeddings is not None:\n",
    "            nodes.append(\n",
    "                TextNode(\n",
    "                    id_=id_, text=text, embedding=corpus_embeddings[i].tolist()\n",
    "                )\n",
    "            )\n",
    "        else:\n",
    "            nodes.append(TextNode(id_=id_, text=text))\n",
    "\n",
    "    index = VectorStoreIndex(\n",
    "        nodes=nodes,\n",
    "        embeddings=corpus_embeddings,\n",
    "        embed_model=embed_model,\n",
    "        show_progress=True,\n",
    "    )\n",
    "    return index.as_retriever(similarity_top_k=k)\n",
    "\n",
    "\n",
    "def ndcg_at_k(\n",
    "    dataset: EmbeddingQAFinetuneDataset, retriever: BaseRetriever, k: int = 10\n",
    "):\n",
    "    queries = dataset.queries\n",
    "    relevant_docs = dataset.relevant_docs\n",
    "    ndcg_scores = []\n",
    "    for query_id, query in tqdm(queries.items()):\n",
    "        retrieved_nodes = retriever.retrieve(query)\n",
    "        retrieved_ids = [node.node.node_id for node in retrieved_nodes]\n",
    "        expected_ids = relevant_docs[query_id]\n",
    "\n",
    "        # Calculate NDCG\n",
    "        ideal_dcg = np.sum(\n",
    "            [1 / np.log2(i + 2) for i in range(min(k, len(expected_ids)))]\n",
    "        )\n",
    "        rel_scores = np.zeros(k)\n",
    "        for j in range(min(k, len(retrieved_ids))):\n",
    "            if retrieved_ids[j] in expected_ids:\n",
    "                rel_scores[j] = 1\n",
    "        dcg = np.sum(\n",
    "            [rel_scores[i] / np.log2(i + 2) for i in range(len(rel_scores))]\n",
    "        )\n",
    "        ndcg = dcg / ideal_dcg if ideal_dcg > 0 else 0\n",
    "\n",
    "        ndcg_scores.append(ndcg)\n",
    "\n",
    "    mean_ndcg = np.mean(ndcg_scores)\n",
    "    return mean_ndcg"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the corpus embedding finetuning results\n",
    "Next, we use [NUDGE](https://www.arxiv.org/abs/2409.02343), a state-of-the-art method for finetuning corpus embeddings, to maximize the accuracy of k-NN retrieval. We then take the new corpus embeddings along with the original embedding model to build a retriever. NUDGE only finetunes the corpus embeddings and does not change any of the parameters in the base embedding model."
   ]
  },
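  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The idea can be sketched in symbols. The notation is ours, introduced for illustration, and is not necessarily the paper's: $f$ is the base embedding model, $e_i$ is the original embedding of record $i$, and $\\Delta_i$ is the learned change to that embedding. Assuming cosine similarity as the similarity measure, the base retriever returns the $k$ records with the highest score\n",
    "\n",
    "$$s_i(q) = \\cos\\big(f(q),\\, e_i\\big),$$\n",
    "\n",
    "while the retriever built from the finetuned corpus embeddings scores records with\n",
    "\n",
    "$$s_i'(q) = \\cos\\big(f(q),\\, e_i + \\Delta_i\\big).$$\n",
    "\n",
    "Only the stored vectors $e_i + \\Delta_i$ change. Queries are still embedded by the same $f$, which is why the code below passes `base_embed_model` to the retriever together with the new corpus embeddings."
   ]
  },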
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.experimental.nudge.base:Use pytorch device: cuda\n",
      "Use pytorch device: cuda\n"
     ]
    }
   ],
   "source": [
    "%%capture\n",
    "from llama_index.experimental import Nudge\n",
    "\n",
    "k = 10\n",
    "\n",
    "nudge = Nudge(\n",
    "    train_dataset=train_dataset,\n",
    "    val_dataset=val_dataset,\n",
    "    embed_model=base_embed_model,\n",
    "    use_nudge_n=True,\n",
    ")\n",
    "nudge.finetune()\n",
    "nudge_corpus_embeddings = nudge.get_finetuned_corpus_embeddings()\n",
    "nudge_retriever = build_retriever(\n",
    "    train_dataset.corpus, base_embed_model, nudge_corpus_embeddings, k=k\n",
    ")\n",
    "nudge_ndcg_test = ndcg_at_k(test_dataset, nudge_retriever, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the adapter finetuning results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.finetuning.embeddings.adapter:Use pytorch device: cuda\n",
      "Use pytorch device: cuda\n",
      "INFO:llama_index.embeddings.adapter.base:Use pytorch device: cuda\n",
      "Use pytorch device: cuda\n"
     ]
    }
   ],
   "source": [
    "%%capture\n",
    "from llama_index.finetuning import EmbeddingAdapterFinetuneEngine\n",
    "\n",
    "embedding_adapter_finetune_engine = EmbeddingAdapterFinetuneEngine(\n",
    "    train_dataset,\n",
    "    base_embed_model,\n",
    "    epochs=4,\n",
    "    batch_size=10,\n",
    ")\n",
    "embedding_adapter_finetune_engine.finetune()\n",
    "embedding_adapter_model = (\n",
    "    embedding_adapter_finetune_engine.get_finetuned_model()\n",
    ")\n",
    "ft_retriever = build_retriever(\n",
    "    train_dataset.corpus, embedding_adapter_model, k=k\n",
    ")\n",
    "ft_ndcg_test = ndcg_at_k(test_dataset, ft_retriever, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get the baseline results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "base_retriever = build_retriever(train_dataset.corpus, base_embed_model, k=k)\n",
    "bge_ndcg_test = ndcg_at_k(test_dataset, base_retriever, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Display the results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "bge test - ndcg@10: 0.71\n",
      "adaptor finetune test - ndcg@10: 0.72\n",
      "NUDGE-N test - ndcg@10: 0.87\n"
     ]
    }
   ],
   "source": [
    "print(f\"bge test - ndcg@10: {bge_ndcg_test:.2f}\")\n",
    "print(f\"adaptor finetune test - ndcg@10: {ft_ndcg_test:.2f}\")\n",
    "print(f\"NUDGE-N test - ndcg@10: {nudge_ndcg_test:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inserting records into the dataset\n",
    "It's common for a dataset to expand over time. We will now insert the nfcorpus records into the scifact example we've been working with and finetune on them. Usually you'd have to retrain on your entire dataset to avoid catastrophic forgetting. With NUDGE, you can expand your dataset iteratively by focusing only on the newest batch of data, without worrying about catastrophic forgetting. This only works when the newly inserted data does not conflict with the existing dataset (e.g. new queries that target the old corpus, or new corpus records that change the k-NN results for old queries)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "new_train_dataset, new_val_dataset, new_test_dataset = load_hf_dataset(\n",
    "    \"nfcorpus\"\n",
    ")\n",
    "\n",
    "# prepend \"nfcorpus-\" to the keys so they don't conflict with the scifact ids\n",
    "new_train_dataset.queries = {\n",
    "    f\"nfcorpus-{k}\": v for k, v in new_train_dataset.queries.items()\n",
    "}\n",
    "new_train_dataset.relevant_docs = {\n",
    "    f\"nfcorpus-{k}\": [f\"nfcorpus-{doc_id}\" for doc_id in v]\n",
    "    for k, v in new_train_dataset.relevant_docs.items()\n",
    "}\n",
    "new_train_dataset.corpus = {\n",
    "    f\"nfcorpus-{k}\": v for k, v in new_train_dataset.corpus.items()\n",
    "}\n",
    "\n",
    "new_val_dataset.queries = {\n",
    "    f\"nfcorpus-{k}\": v for k, v in new_val_dataset.queries.items()\n",
    "}\n",
    "new_val_dataset.relevant_docs = {\n",
    "    f\"nfcorpus-{k}\": [f\"nfcorpus-{doc_id}\" for doc_id in v]\n",
    "    for k, v in new_val_dataset.relevant_docs.items()\n",
    "}\n",
    "new_val_dataset.corpus = {\n",
    "    f\"nfcorpus-{k}\": v for k, v in new_val_dataset.corpus.items()\n",
    "}\n",
    "\n",
    "new_test_dataset.queries = {\n",
    "    f\"nfcorpus-{k}\": v for k, v in new_test_dataset.queries.items()\n",
    "}\n",
    "new_test_dataset.relevant_docs = {\n",
    "    f\"nfcorpus-{k}\": [f\"nfcorpus-{doc_id}\" for doc_id in v]\n",
    "    for k, v in new_test_dataset.relevant_docs.items()\n",
    "}\n",
    "new_test_dataset.corpus = {\n",
    "    f\"nfcorpus-{k}\": v for k, v in new_test_dataset.corpus.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Finetune the new records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "nudge.insert_data_and_finetune(\n",
    "    new_train_dataset_batch=new_train_dataset,\n",
    "    new_val_dataset_batch=new_val_dataset,\n",
    ")\n",
    "# get our corpus embeddings with the newly inserted and tuned records\n",
    "nudge_corpus_embeddings = nudge.get_finetuned_corpus_embeddings()\n",
    "# aggregate the corpus\n",
    "aggregated_corpus = {**train_dataset.corpus, **new_train_dataset.corpus}\n",
    "# build nudge retriever\n",
    "nudge_retriever = build_retriever(\n",
    "    aggregated_corpus, base_embed_model, nudge_corpus_embeddings, k=k\n",
    ")\n",
    "# get test results on nfcorpus\n",
    "nudge_ndcg_nfcorpus_test = ndcg_at_k(new_test_dataset, nudge_retriever, k)\n",
    "# get test results on scifact\n",
    "nudge_ndcg_scifact_test = ndcg_at_k(test_dataset, nudge_retriever, k)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Display the insertion results\n",
    "Check the results on our newly inserted nfcorpus records and verify that our scifact benchmark did not regress."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NUDGE-N (aggregated) test on nfcorpus - ndcg@10: 0.44\n",
      "NUDGE-N (aggregated) test on scifact - ndcg@10: 0.85\n"
     ]
    }
   ],
   "source": [
    "print(\n",
    "    f\"NUDGE-N (aggregated) test on nfcorpus - ndcg@10: {nudge_ndcg_nfcorpus_test:.2f}\"\n",
    ")\n",
    "print(\n",
    "    f\"NUDGE-N (aggregated) test on scifact - ndcg@10: {nudge_ndcg_scifact_test:.2f}\"\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
